import math
import random
import DeepCCG


class DeepCCG_reservoir(DeepCCG.DeepCCG):
    """DeepCCG variant whose replay memory is maintained by skip-based reservoir sampling.

    Stored samples live in ``per_class_mem``, a dict mapping a class label to a
    list of ``(x, task_id)`` tuples.  While the memory still has free space every
    incoming sample is stored.  Once it is full, ``self.w`` and ``self.next`` hold
    the state of the skip sampler: ``self.next`` is the stream position (counted
    like ``self.seen_count``) of the next sample that will be written into memory.
    """

    # Class-level fallback kept so ``DeepCCG_reservoir.per_class_mem`` still resolves;
    # ``__init__`` makes sure every instance works on its own dict.
    per_class_mem = {}

    def __init__(self, args, use_task_inc_loss=True, mem_size=1000, mem_batch_size=10, min_num_mem_points=1,
                 num_of_iter_new_mem=5, mem_lr=0.1, fit_mem_batch_size=64, diag_approx=True,  # 30  # 32
                 new_mem_diag_approx=False, s_coeff=1, fixed_var=True, calc_stat_batch_size=1028,  # 256
                 mem_select_KLD=True, store_old_means=False):
        """Forward all arguments to ``DeepCCG.DeepCCG`` and reset the reservoir state."""
        super().__init__(args, use_task_inc_loss, mem_size, mem_batch_size, min_num_mem_points,
                         num_of_iter_new_mem, mem_lr, fit_mem_batch_size, diag_approx,
                         new_mem_diag_approx, s_coeff, fixed_var, calc_stat_batch_size,
                         mem_select_KLD, store_old_means)
        # The class attribute is a single dict shared by all instances, while the
        # counters below are per instance.  Give this instance its own memory unless
        # the parent constructor already attached one.
        if "per_class_mem" not in self.__dict__:
            self.per_class_mem = {}
        self.remaining_space = mem_size  # free slots left before the memory is full
        self.seen_count = 0              # number of stream samples processed so far
        self.next = 0                    # stream position of the next sample to store
        self.w = 0                       # running weight of the skip sampler
        self.full = False                # True once ``mem_size`` samples are stored

    def _reservoir_skip(self):
        """Draw the distance (always >= 1) to the next stream position to store, based on ``self.w``."""
        return math.floor(math.log(random.uniform(0, 1)) / math.log(1 - self.w)) + 1

    # calculates the new memory after training on a batch
    def calc_new_mem(self, X, Y):
        """Update ``per_class_mem`` with the samples of the current batch.

        The batch is read from ``self.batch`` (inputs at index 0, labels at index 1)
        and tagged with ``self.task_id``; ``X`` and ``Y`` are accepted for interface
        compatibility but are not used.

        Steps:
          1. While the memory is not full, store samples until it is.
          2. For every remaining sample whose class is not in memory yet, evict a
             random sample of a class holding more than one sample and store the
             new sample as the first of its class.
          3. Store the samples selected by the skip sampler, each replacing a random
             stored sample of a class holding more than one sample.

        Returns:
            None.  ``per_class_mem``, ``remaining_space``, ``seen_count``, ``full``,
            ``w`` and ``next`` are updated in place.
        """
        mem = self.per_class_mem
        inputs, labels = self.batch[0], self.batch[1]
        batch = [(inputs[i], labels[i].item(), self.task_id) for i in range(inputs.shape[0])]

        # --- 1. fill phase -------------------------------------------------------
        if not self.full:
            n_fill = min(len(batch), self.remaining_space)
            for x, y, t in batch[:n_fill]:
                mem.setdefault(y, []).append((x, t))
            self.remaining_space -= n_fill
            self.seen_count += n_fill
            if self.remaining_space == 0:
                self.full = True
                self.w = math.exp(math.log(random.uniform(0, 1)) / self.mem_size)
                # NOTE(review): the skip is at least 1, so the first sample after the
                # fill (offset 0) can never be selected.  The usual skip-based
                # reservoir scheme would allow it -- confirm this is intended.
                self.next = self.seen_count + self._reservoir_skip()
            if n_fill == len(batch):
                return
            batch = batch[n_fill:]

        # --- 2. guarantee a slot for classes not yet in memory ---------------------
        for x, y, t in batch:
            if y in mem:
                continue
            donors = [label for label, items in mem.items() if len(items) > 1]
            if not donors:
                # Every stored class holds a single sample: nothing can be evicted
                # without dropping a class, so the new class cannot be stored.
                continue
            donor = random.choice(donors)
            del mem[donor][random.randrange(len(mem[donor]))]
            mem[y] = [(x, t)]

        # --- 3. reservoir replacement ---------------------------------------------
        last_position = self.seen_count + len(batch) - 1
        while self.next <= last_position:
            x, y, t = batch[self.next - self.seen_count]
            # Rebuild the candidate slots from the *current* memory on every pass:
            # phases 1 and 2 (and previous iterations) change list lengths, so a
            # cached index would point at wrong or missing entries.  Only classes
            # with more than one sample are candidates, so each replacement removes
            # exactly one sample and the memory never exceeds ``mem_size``.
            evictable = [(idx, label)
                         for label, items in mem.items() if len(items) > 1
                         for idx in range(len(items))]
            if evictable:
                idx, label = random.choice(evictable)
                del mem[label][idx]
                mem.setdefault(y, []).append((x, t))
            self.w = self.w * math.exp(math.log(random.uniform(0, 1)) / self.mem_size)
            self.next = self.next + self._reservoir_skip()

        self.seen_count += len(batch)
